{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "9f0d0f32-23b4-41a6-b364-579da297c326"
      },
      "outputs": [],
      "source": [
        "# @title Copyright & License (click to expand)\n",
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dd53d60c-97eb-4c72-91ea-f274a753ab34"
      },
      "source": [
        "# Supervised Fine Tuning with Gemini 2.5 Flash for Image Captioning\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Ftuning%2Fsft_gemini_on_image_captioning.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>    \n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\">\n",
        "      <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/tuning/sft_gemini_on_image_captioning.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MgVK7IeKpW27"
      },
      "source": [
        "| Author(s) |\n",
        "| --- |\n",
        "| [Deepak Moonat](https://github.com/dmoonat) |\n",
        "| [Erwin Huizenga](https://github.com/erwinh85) |\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ef820fb-1203-4cab-965f-17093a4ba25e"
      },
      "source": [
        "## Overview\n",
        "\n",
        "**Gemini** is a family of generative AI models developed by Google DeepMind, designed for multimodal use cases. The Gemini API gives you access to the various Gemini models, such as Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 2.0 Flash, and more.\n",
        "\n",
        "This notebook demonstrates how to fine-tune the Gemini 2.5 Flash generative model using the Vertex AI Supervised Tuning feature. Supervised Tuning allows you to use your own training data to further refine the base model's capabilities towards your specific tasks.\n",
        "\n",
        "\n",
        "Supervised Tuning uses labeled examples to tune a model. Each example demonstrates the output you want from your model during inference.\n",
        "\n",
        "First, ensure your training data is high quality, well labeled, and directly relevant to the target task. This is crucial, as low-quality data can adversely affect performance and introduce bias into the fine-tuned model.\n",
        "- Training: Experiment with different configurations to optimize the model's performance on the target task.\n",
        "- Evaluation:\n",
        "  - Metric: Choose evaluation metrics that accurately reflect the success of the fine-tuned model for your specific task.\n",
        "  - Evaluation set: Use a separate set of data to evaluate the model's performance."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "74b00940-376c-4056-90fb-d22c1ce6eedf"
      },
      "source": [
        "### Objective\n",
        "\n",
        "In this tutorial, you will learn how to use `Vertex AI` to tune a `Gemini 2.5 Flash` model.\n",
        "\n",
        "\n",
        "This tutorial uses the following Google Cloud ML services:\n",
        "\n",
        "- `Vertex AI`\n",
        "\n",
        "\n",
        "The steps performed include:\n",
        "\n",
        "- Prepare and load the dataset\n",
        "- Load the `gemini-2.5-flash` model\n",
        "- Evaluate the model before tuning\n",
        "- Tune the model\n",
        "  - This automatically creates a Vertex AI endpoint and deploys the tuned model to it\n",
        "- Evaluate the model after tuning\n",
        "- Make a prediction using the tuned model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X0xdTMs10K7y"
      },
      "source": [
        "### Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jCMczwd00N9T"
      },
      "source": [
        "The dataset used in this notebook is an image captioning dataset. [Reference](https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#download_the_model_checkpoint)\n",
        "\n",
        "```\n",
        "Licensed under the Creative Commons Attribution 4.0 License\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6d7b5435-e947-49bb-9ce3-aa8a42c30118"
      },
      "source": [
        "### Costs\n",
        "\n",
        "This tutorial uses billable components of Google Cloud:\n",
        "\n",
        "* Vertex AI\n",
        "* Cloud Storage\n",
        "\n",
        "Learn about [Vertex AI\n",
        "pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage\n",
        "pricing](https://cloud.google.com/storage/pricing), and use the [Pricing\n",
        "Calculator](https://cloud.google.com/products/calculator/)\n",
        "to generate a cost estimate based on your projected usage."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0cbf01f0-5f6e-4bcd-903f-84ccaad5332c"
      },
      "source": [
        "### Install Gen AI SDK and other required packages"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b8e4d4521362"
      },
      "source": [
        "The new Google Gen AI SDK provides a unified interface to Gemini through both the Gemini Developer API and the Gemini API on Vertex AI. With a few exceptions, code that runs on one platform will run on both. This means that you can prototype an application using the Developer API and then migrate the application to Vertex AI without rewriting your code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MpDAgOsK6kZn"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-genai google-cloud-aiplatform jsonlines rouge_score"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Moror1y0Qq2z"
      },
      "source": [
        "### Restart runtime (Colab only)\n",
        "\n",
        "To use the newly installed packages, you must restart the runtime on Google Colab."
      ]
    },

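    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell below is one common way to restart the kernel in Colab (assumed here; adjust for your environment): it shuts down the IPython kernel so that the newly installed packages are available in the new session."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import IPython\n",
        "\n",
        "# Shut down the kernel; Colab restarts it automatically, so the\n",
        "# packages installed above are picked up in the new session.\n",
        "app = IPython.Application.instance()\n",
        "app.kernel.do_shutdown(True)"
      ]
    },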
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dpSnJTbIrFsh"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b37d4259-7e39-417b-8879-24f7575732c8"
      },
      "source": [
        "## Before you begin\n",
        "\n",
        "### Set your project ID\n",
        "\n",
        "**If you don't know your project ID**, try the following:\n",
        "* Run `gcloud config list`.\n",
        "* Run `gcloud projects list`.\n",
        "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "caaf0d7e-c6cb-4e56-af5c-553db5180e00"
      },
      "outputs": [],
      "source": [
        "PROJECT_ID = \"[YOUR_PROJECT_ID]\"  # @param {type:\"string\"}\n",
        "# Set the project id\n",
        "! gcloud config set project {PROJECT_ID}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "054d794d-cd2e-4280-95ac-859b264ea2d6"
      },
      "source": [
        "#### Region\n",
        "\n",
        "You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0121bf60-1acd-4272-afaf-aa54b4ded263"
      },
      "outputs": [],
      "source": [
        "REGION = \"us-central1\"  # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "czjH2JfKaGfH"
      },
      "source": [
        "#### Bucket\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c_iZzYtraF3y"
      },
      "outputs": [],
      "source": [
        "BUCKET_NAME = \"[YOUR_BUCKET_NAME]\"  # @param {type:\"string\"}\n",
        "BUCKET_URI = f\"gs://{BUCKET_NAME}\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eac9e842-d225-4876-836f-afdb1937d800"
      },
      "source": [
        "### Authenticate your Google Cloud account\n",
        "\n",
        "Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.\n",
        "\n",
        "**1. Vertex AI Workbench**\n",
        "* Do nothing as you are already authenticated.\n",
        "\n",
        "**2. Local JupyterLab instance, uncomment and run:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "23082eec-b1bd-4594-b5b5-56fe2b74db6f"
      },
      "outputs": [],
      "source": [
        "# ! gcloud auth login"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3c20f923-3c46-4d6d-80d2-d7cb22b1a8da"
      },
      "source": [
        "**3. Authenticate your notebook environment**\n",
        "\n",
        "If you are running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "60302a3f-fad9-452c-8998-a9c9822d2732"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "\n",
        "auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ac33116d-b079-46cb-9614-86326c211e00"
      },
      "source": [
        "**4. Service account or other**\n",
        "* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e6a924d0-a034-4e53-b240-03d356c7b7a6"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "463729ba-ec3c-4302-95bf-80207b0f9e2d"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "import time\n",
        "\n",
        "# For visualization.\n",
        "from PIL import Image\n",
        "from google import genai\n",
        "\n",
        "# For Google Cloud Storage service.\n",
        "from google.cloud import storage\n",
        "\n",
        "# For fine tuning Gemini model.\n",
        "import google.cloud.aiplatform as aiplatform\n",
        "from google.genai import types\n",
        "\n",
        "# For data handling.\n",
        "import jsonlines\n",
        "import pandas as pd\n",
        "\n",
        "# For evaluation.\n",
        "from rouge_score import rouge_scorer\n",
        "from tqdm import tqdm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a522acfe-d0b6-4b4e-b201-0a4ccf59b133"
      },
      "source": [
        "## Initialize Vertex AI and Gen AI SDK for Python"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c845aca6-4f72-4d3b-b9ed-de4a18fcbbf8"
      },
      "outputs": [],
      "source": [
        "aiplatform.init(project=PROJECT_ID, location=REGION)\n",
        "\n",
        "client = genai.Client(vertexai=True, project=PROJECT_ID, location=REGION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "okht6CExcw4d"
      },
      "source": [
        "## Prepare Multimodal Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8N1QCz0MzyD6"
      },
      "source": [
        "The dataset used to tune a foundation model needs to include examples that align with the task that you want the model to perform."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9yp9SQ1M7FSP"
      },
      "source": [
        "Note:\n",
        "- Only images and text are supported as input; the output must be text only.\n",
        "- Maximum of 30 images per tuning example.\n",
        "- Maximum image file size: 20MB.\n",
        "- Images need to be in `jpeg` or `png` format. Supported mimetypes: `image/jpeg` and `image/png`.\n",
        "\n",
        "The input is a jsonl file with each JSON instance on its own line.\n",
        "Each JSON instance has the following format (expanded for clarity):\n",
        "```\n",
        "{\n",
        "   \"contents\":[\n",
        "      {\n",
        "         \"role\":\"user\",  # Indicates the input content\n",
        "         \"parts\":[ # Interleaved image and text; can be in any order.\n",
        "            {\n",
        "               \"fileData\":{ # fileData must reference an image file in GCS. No inline data.\n",
        "                  \"mimeType\":\"image/jpeg\", # The mimeType of this image\n",
        "                  \"fileUri\":\"gs://path/to/image_uri\"\n",
        "               }\n",
        "            },\n",
        "            {\n",
        "               \"text\":\"What is in this image?\"\n",
        "            }\n",
        "         ]\n",
        "      },\n",
        "      {\n",
        "         \"role\":\"model\", # Indicates the target content\n",
        "         \"parts\":[ # Text only\n",
        "            {\n",
        "               \"text\":\"Something about this image.\"\n",
        "            }\n",
        "         ]\n",
        "      } # Single-turn input and response.\n",
        "   ]\n",
        "}\n",
        "```\n",
        "\n",
        "Example:\n",
        "```\n",
        "{\n",
        "   \"contents\":[\n",
        "      {\n",
        "         \"role\":\"user\",\n",
        "         \"parts\":[\n",
        "            {\n",
        "               \"fileData\":{\n",
        "                  \"mimeType\":\"image/jpeg\",\n",
        "                  \"fileUri\":\"gs://bucketname/data/vision_data/task/image_description/image/1.jpeg\"\n",
        "               }\n",
        "            },\n",
        "            {\n",
        "               \"text\":\"Describe this image that captures the essence of it.\"\n",
        "            }\n",
        "         ]\n",
        "      },\n",
        "      {\n",
        "         \"role\":\"model\",\n",
        "         \"parts\":[\n",
        "            {\n",
        "               \"text\":\"A person wearing a pink shirt and a long-sleeved shirt with a large cuff, ....\"\n",
        "            }\n",
        "         ]\n",
        "      }\n",
        "   ]\n",
        "}\n",
        "```\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DESw8v4QrLHR"
      },
      "source": [
        "### Data files\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uiTVJqMXTvM5"
      },
      "source": [
        "The data used in this notebook is stored in the public Cloud Storage (GCS) bucket `gs://longcap100`.\n",
        "\n",
        "Sample:\n",
        "\n",
        "> {\"prefix\": \"\", \"suffix\": \"A person wearing a pink shirt and a long-sleeved shirt with a large cuff, has their hand on a concrete ledge. The hand is on the edge of the ledge, and the thumb is on the edge of the hand. The shirt has a large cuff, and the sleeve is rolled up. The shadow of the hand is on the wall.\", \"image\": \"91.jpeg\"}\n",
        "\n",
        "\n",
        "\n",
        "- `data_train90.jsonl`: Contains training samples as JSON lines, as shown above\n",
        "- `data_val10.jsonl`: Contains validation samples as JSON lines, as shown above\n",
        "- The bucket also contains the 100 `jpeg` images used in the training and validation data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MLcuIXlzz36C"
      },
      "source": [
        "To run a tuning job, you need to upload one or more datasets to a Cloud Storage bucket. You can either create a new Cloud Storage bucket or use an existing one to store dataset files. The region of the bucket doesn't matter, but we recommend that you use a bucket that's in the same Google Cloud project where you plan to tune your model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sfIUgj-mU8K9"
      },
      "source": [
        "### Create a Cloud Storage bucket"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T_uC6nuFU-XU"
      },
      "source": [
        "- Create a storage bucket to store intermediate artifacts such as datasets.\n",
        "\n",
        "- Only if your bucket doesn't already exist: Run the following cell to create your Cloud Storage bucket.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M-L1BH8TU9Gn"
      },
      "outputs": [],
      "source": [
        "!gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZUGi7ZThbChr"
      },
      "source": [
        "### Copy images to the specified bucket"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DHdC-9nj071o"
      },
      "outputs": [],
      "source": [
        "!gsutil -m -q cp -n -r gs://longcap100/*.jpeg {BUCKET_URI}/images/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fpyJR6tlVRXh"
      },
      "source": [
        "- Download the training and validation jsonl files from the public bucket."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "peUixIt_2DLP"
      },
      "outputs": [],
      "source": [
        "!gsutil -m -q cp -n -r gs://longcap100/data_train90.jsonl ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rtXMRqAi1WiF"
      },
      "outputs": [],
      "source": [
        "!gsutil -m -q cp -n -r gs://longcap100/data_val10.jsonl ."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a9N-rN7pECKa"
      },
      "source": [
        "### Prepare dataset for Training and Evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEfGLRVfsrii"
      },
      "source": [
        "- Utility function to save json instances into jsonlines format"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zdVGCwFWsrCB"
      },
      "outputs": [],
      "source": [
        "def save_jsonlines(file, instances):\n",
        "    \"\"\"\n",
        "    Saves a list of json instances to a jsonlines file.\n",
        "    \"\"\"\n",
        "    with jsonlines.open(file, mode=\"w\") as writer:\n",
        "        writer.write_all(instances)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-hMIYgYBsbUt"
      },
      "source": [
        "- The function below converts the dataset into the Gemini tuning format"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0TFcj_tjaALV"
      },
      "outputs": [],
      "source": [
        "task_prompt = \"Describe this image in detail that captures the essence of it.\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LZ1cauVkz8Vv"
      },
      "outputs": [],
      "source": [
        "def create_tuning_samples(file_path):\n",
        "    \"\"\"\n",
        "    Creates tuning samples from a file.\n",
        "    \"\"\"\n",
        "    with jsonlines.open(file_path) as reader:\n",
        "        instances = []\n",
        "        for obj in reader:\n",
        "            instance = {\n",
        "                \"contents\": [\n",
        "                    {\n",
        "                        \"role\": \"user\",  # Indicates the input content\n",
        "                        \"parts\": [  # Interleaved image and text; can be in any order.\n",
        "                            {\n",
        "                                \"fileData\": {  # fileData must reference an image file in GCS. No inline data.\n",
        "                                    \"mimeType\": \"image/jpeg\",  # The mimeType of this image\n",
        "                                    \"fileUri\": f\"{BUCKET_URI}/images/{obj['image']}\",\n",
        "                                }\n",
        "                            },\n",
        "                            {\"text\": task_prompt},\n",
        "                        ],\n",
        "                    },\n",
        "                    {\n",
        "                        \"role\": \"model\",  # Indicates the target content\n",
        "                        \"parts\": [{\"text\": obj[\"suffix\"]}],  # Text only\n",
        "                    },  # Single-turn input and response.\n",
        "                ]\n",
        "            }\n",
        "            instances.append(instance)\n",
        "    return instances"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tqh6WYHg6X4z"
      },
      "source": [
        "- Training data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b685Iy27z1E1"
      },
      "outputs": [],
      "source": [
        "train_file_path = \"data_train90.jsonl\"\n",
        "train_instances = create_tuning_samples(train_file_path)\n",
        "# save the training instances to jsonl file\n",
        "save_jsonlines(\"train.jsonl\", train_instances)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UC4ULRC46mA-"
      },
      "outputs": [],
      "source": [
        "train_instances[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nyn5Xgw41bhc"
      },
      "outputs": [],
      "source": [
        "# save the training data to GCS bucket\n",
        "!gsutil cp train.jsonl {BUCKET_URI}/train/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HLsC3IBL6ZWk"
      },
      "source": [
        "- Validation data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LIp0hdag6bS0"
      },
      "outputs": [],
      "source": [
        "val_file_path = \"data_val10.jsonl\"\n",
        "val_instances = create_tuning_samples(val_file_path)\n",
        "# save the validation instances to jsonl file\n",
        "save_jsonlines(\"val.jsonl\", val_instances)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TBTBTx4n6koL"
      },
      "outputs": [],
      "source": [
        "val_instances[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xy-6ihNR6gx3"
      },
      "outputs": [],
      "source": [
        "# save the validation data to GCS bucket\n",
        "!gsutil cp val.jsonl {BUCKET_URI}/val/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QhejcJumTAj3"
      },
      "source": [
        "- The code below transforms the jsonl format into the following structure:\n",
        "\n",
        "```\n",
        "[{'file_uri': '<GCS path for query image>',\n",
        "  'ground_truth': '<Ground truth image description>'},\n",
        " ...\n",
        "]\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uEGTSHPx4aMl"
      },
      "outputs": [],
      "source": [
        "data_table = []\n",
        "for instance in val_instances:\n",
        "    data_table.append(\n",
        "        {\n",
        "            \"file_uri\": instance[\"contents\"][0][\"parts\"][0][\"fileData\"][\"fileUri\"],\n",
        "            \"ground_truth\": instance[\"contents\"][1][\"parts\"][0][\"text\"],\n",
        "        }\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JB2lp-0c0TA4"
      },
      "outputs": [],
      "source": [
        "data_table[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MQOXFXUF1u1Y"
      },
      "source": [
        "- The `data_table` list is converted into a dataframe with two columns, `file_uri` and `ground_truth`. The `ground_truth` value will be compared with the model-generated output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B3oaGk0XMsg3"
      },
      "outputs": [],
      "source": [
        "val_df = pd.DataFrame(data_table)\n",
        "val_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iUweYYURlDg5"
      },
      "source": [
        "- There are `10` instances in total in the validation data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dw-gQpLXTe67"
      },
      "source": [
        "## Visualization utils"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1h2xl9igTiT7"
      },
      "source": [
        "- Function to visualize the query images stored in GCS bucket"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZTTMVUOhTeur"
      },
      "outputs": [],
      "source": [
        "# Read image bytes from a file in a GCS bucket\n",
        "\n",
        "\n",
        "def read_image_bytes_from_gcs(bucket_name, blob_name):\n",
        "    \"\"\"Reads image bytes from a GCS bucket.\n",
        "\n",
        "    Args:\n",
        "      bucket_name: The name of the GCS bucket.\n",
        "      blob_name: The name of the blob (file) within the bucket.\n",
        "\n",
        "    Returns:\n",
        "      The image bytes as a bytes object.\n",
        "    \"\"\"\n",
        "\n",
        "    storage_client = storage.Client()\n",
        "    bucket = storage_client.bucket(bucket_name)\n",
        "    blob = bucket.blob(blob_name)\n",
        "\n",
        "    image_bytes = blob.download_as_bytes()\n",
        "\n",
        "    return image_bytes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2tV6nWuD59-H"
      },
      "source": [
        "## Evaluation Pre-Tuning"
      ]
    },
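    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "- As a brief sketch (the caption strings here are illustrative, not from the dataset), this is how the `rouge_score` package imported above computes a ROUGE-L score between a ground-truth caption and a candidate caption:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal ROUGE-L example: compare a ground-truth caption with a candidate.\n",
        "example_scorer = rouge_scorer.RougeScorer([\"rougeL\"], use_stemmer=True)\n",
        "example_scores = example_scorer.score(\n",
        "    \"A person wearing a pink shirt.\",  # target (ground truth)\n",
        "    \"A person in a pink shirt.\",  # prediction (model output)\n",
        ")\n",
        "example_scores[\"rougeL\"].fmeasure"
      ]
    },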
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tn1uCzlT1j42"
      },
      "source": [
        "- Assign `gemini-2.5-flash` as the `base_model`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2DLk7RYYPhi0"
      },
      "outputs": [],
      "source": [
        "base_model = \"gemini-2.5-flash\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "usabcedw0EVT"
      },
      "source": [
        "### Generation config\n",
        "\n",
        "- Each call that you send to a model includes parameter values that control how the model generates a response. The model can generate different results for different parameter values\n",
        "- <strong>Experiment</strong> with different parameter values to get the best values for the task\n",
        "\n",
        "Refer to the following [link](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/adjust-parameter-values) for understanding different parameters"
      ]
    },
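    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gen-config-sketch-md"
      },
      "source": [
        "For illustration, a config for this task might look like the sketch below. The values shown are assumptions to experiment with, not tuned recommendations:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gen-config-sketch-code"
      },
      "outputs": [],
      "source": [
        "# Illustrative generation config -- these values are assumptions, not tuned\n",
        "# recommendations for this captioning task.\n",
        "example_generation_config = {\n",
        "    \"temperature\": 0.0,  # lower values make output more deterministic (useful for evaluation)\n",
        "    \"top_p\": 0.95,  # nucleus sampling: restrict sampling to the smallest token set with mass >= top_p\n",
        "    \"max_output_tokens\": 256,  # upper bound on the length of the generated caption\n",
        "}"
      ]
    },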
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zUx23W_r0F8z"
      },
      "source": [
        "**Prompt** is a natural language request submitted to a language model to receive a response back\n",
        "\n",
        "Some best practices include\n",
        "  - Clearly communicate what content or information is most important\n",
        "  - Structure the prompt:\n",
        "    - Defining the role if using one. For example, You are an experienced UX designer at a top tech company\n",
        "    - Include context and input data\n",
        "    - Provide the instructions to the model\n",
        "    - Add example(s) if you are using them\n",
        "\n",
        "Refer to the following [link](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/prompt-design-strategies) for prompt design strategies."
      ]
    },
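    {
      "cell_type": "markdown",
      "metadata": {
        "id": "prompt-structure-sketch-md"
      },
      "source": [
        "As a sketch of the structure above (role, context, instructions, example), a hypothetical prompt could be assembled like this. All the strings below are illustrative and are not the prompt this notebook uses:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "prompt-structure-sketch-code"
      },
      "outputs": [],
      "source": [
        "# Hypothetical prompt pieces -- illustrative only, not the task prompt used later.\n",
        "role = \"You are an experienced photo editor at a stock-imagery company.\"\n",
        "context = \"You will be shown a single photograph.\"\n",
        "instruction = \"Write one concise caption that captures the essence of the image.\"\n",
        "example = 'Example caption: \"A golden retriever leaps to catch a frisbee on a sunny beach.\"'\n",
        "\n",
        "example_prompt = \"\\n\".join([role, context, instruction, example])\n",
        "print(example_prompt)"
      ]
    },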
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uuKHRy2OVX0w"
      },
      "source": [
        "### Task"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U-YD1J3VTSoI"
      },
      "source": [
        "***Task prompt:***\n",
        "\n",
        "`\n",
        "\"<image>, Describe this image that captures the essence of it. \"\n",
        "`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zTZS4IJMTVR1"
      },
      "source": [
        "***Query Image (image)***\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Ry2IjT2TWwd"
      },
      "outputs": [],
      "source": [
        "query_image_uri = val_instances[0][\"contents\"][0][\"parts\"][0][\"fileData\"][\"fileUri\"]\n",
        "blob_name = query_image_uri.replace(f\"{BUCKET_URI}/\", \"\")\n",
        "img = read_image_bytes_from_gcs(BUCKET_NAME, blob_name)\n",
        "\n",
        "# Display image bytes using pil python library\n",
        "image = Image.open(io.BytesIO(img))\n",
        "resized_img = image.resize((300, 300))\n",
        "display(resized_img)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "04lAlLK53IYS"
      },
      "source": [
        "- Test on single instance"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-MeiP8z-o6qt"
      },
      "outputs": [],
      "source": [
        "response = client.models.generate_content(\n",
        "    model=base_model,\n",
        "    contents=[\n",
        "        types.Part.from_uri(file_uri=str(query_image_uri), mime_type=\"image/jpeg\"),\n",
        "        \"Describe this image that captures the essence of it.\",\n",
        "    ],\n",
        "    # Optional config\n",
        "    config={\n",
        "        \"temperature\": 0.0,\n",
        "    },\n",
        ")\n",
        "\n",
        "print(response.text.strip())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5LISwh5_4R1U"
      },
      "source": [
        "- Ground truth"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aGXbUVK-3lO5"
      },
      "outputs": [],
      "source": [
        "val_instances[0][\"contents\"][1][\"parts\"][0][\"text\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MRVAwGLB6KUX"
      },
      "source": [
        "- Change prompt to get detailed description for the provided image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JO-C5BAVsdfd"
      },
      "outputs": [],
      "source": [
        "response = client.models.generate_content(\n",
        "    model=base_model,\n",
        "    contents=[\n",
        "        types.Part.from_uri(file_uri=str(query_image_uri), mime_type=\"image/jpeg\"),\n",
        "        \"Describe this image in detail that captures the essence of it.\",\n",
        "    ],\n",
        "    # Optional config\n",
        "    config={\n",
        "        \"temperature\": 0.0,\n",
        "    },\n",
        ")\n",
        "\n",
        "print(response.text.strip())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "snYSjdzCVjGA"
      },
      "source": [
        "## Evaluation before model tuning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vVvGqqTSVzUZ"
      },
      "source": [
        "- Evaluate the Gemini model on the validation dataset before tuning it on the training dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "otIRm3XBwQnW"
      },
      "outputs": [],
      "source": [
        "def get_prediction(query_image_uri, base_model):\n",
        "    \"\"\"Gets the prediction for a given instance.\n",
        "\n",
        "    Args:\n",
        "      query_image: The path to the query image.\n",
        "      candidates: A list of paths to the candidate images.\n",
        "      generation_model: The generation model to use for prediction.\n",
        "\n",
        "    Returns:\n",
        "      A string containing the prediction.\n",
        "    \"\"\"\n",
        "    response = client.models.generate_content(\n",
        "        model=base_model,\n",
        "        contents=[\n",
        "            types.Part.from_uri(file_uri=str(query_image_uri), mime_type=\"image/jpeg\"),\n",
        "            task_prompt,\n",
        "        ],\n",
        "        # Optional config\n",
        "        config={\n",
        "            \"temperature\": 0.0,\n",
        "        },\n",
        "    )\n",
        "\n",
        "    return response.text.strip()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rRW5UVau3xfO"
      },
      "outputs": [],
      "source": [
        "def run_eval(val_df, model=base_model):\n",
        "    \"\"\"Runs evaluation on the validation dataset.\n",
        "\n",
        "    Args:\n",
        "      val_df: The validation dataframe.\n",
        "      generation_model: The generation model to use for evaluation.\n",
        "\n",
        "    Returns:\n",
        "      A list of predictions on val_df.\n",
        "    \"\"\"\n",
        "    predictions = []\n",
        "    for i, row in tqdm(val_df.iterrows(), total=val_df.shape[0]):\n",
        "        try:\n",
        "            prediction = get_prediction(row[\"file_uri\"], model)\n",
        "        except:\n",
        "            time.sleep(30)\n",
        "            prediction = get_prediction(row[\"file_uri\"], model)\n",
        "        predictions.append(prediction)\n",
        "        time.sleep(1)\n",
        "    return predictions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "29O4EccbqbIa"
      },
      "source": [
        "- Evaluate the Gemini model on the test dataset before tuning it on the training dataset.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0LunPnr5Tvce"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ It will take ~1 min for the model to generate predictions on the provided validation dataset. ⚠️</b>\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y2Uy75youUor"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "predictions = run_eval(val_df, model=base_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7BOg0EZpgg3D"
      },
      "outputs": [],
      "source": [
        "len(predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N22X-_V5mlev"
      },
      "outputs": [],
      "source": [
        "val_df.loc[:, \"basePredictions\"] = predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bzA_YLSQ67Jc"
      },
      "outputs": [],
      "source": [
        "val_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nbPYwzNVWgz-"
      },
      "source": [
        "### Evaluation metric"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mvqIYHNCWigP"
      },
      "source": [
        "The type of metrics used for evaluation depends on the task that you are evaluating. The following table shows the supported tasks and the metrics used to evaluate each task:\n",
        "\n",
        "| Task             | Metric(s)                     |\n",
        "|-----------------|---------------------------------|\n",
        "| Classification   | Micro-F1, Macro-F1, Per class F1 |\n",
        "| Summarization    | ROUGE-L                         |\n",
        "| Question Answering | Exact Match                     |\n",
        "| Text Generation  | BLEU, ROUGE-L                   |\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BTkLeYDJWre1"
      },
      "source": [
        "For this task, we'll using ROUGE metric.\n",
        "\n",
        "- **Recall-Oriented Understudy for Gisting Evaluation (ROUGE)**: A metric used to evaluate the quality of automatic summaries of text. It works by comparing a generated summary to a set of reference summaries created by humans."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TIlOr8KFWzqt"
      },
      "source": [
        "Now you can take the candidate and reference to evaluate the performance. In this case, ROUGE will give you:\n",
        "\n",
        "- `rouge-1`, which measures unigram overlap\n",
        "- `rouge-2`, which measures bigram overlap\n",
        "- `rouge-l`, which measures the longest common subsequence"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sIVb60EaW2oW"
      },
      "source": [
        "- *Recall vs. Precision*\n",
        "\n",
        "    **Recall**, meaning it prioritizes how much of the information in the reference summaries is captured in the generated summary.\n",
        "\n",
        "    **Precision**, which measures how much of the generated summary is relevant to the original text."
      ]
    },
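    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rouge-toy-sketch-md"
      },
      "source": [
        "To make the distinction concrete, here is a toy, pure-Python sketch of unigram (ROUGE-1) recall and precision. The real scores below come from the `rouge_scorer` package, which also counts repeated tokens and applies stemming:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rouge-toy-sketch-code"
      },
      "outputs": [],
      "source": [
        "# Toy ROUGE-1 illustration (pure Python, unique unigrams only, no stemming).\n",
        "reference = \"a brown dog runs on the beach\".split()\n",
        "candidate = \"a dog runs fast\".split()\n",
        "\n",
        "overlap = len(set(reference) & set(candidate))  # shared unigrams: {\"a\", \"dog\", \"runs\"}\n",
        "recall = overlap / len(set(reference))  # share of reference unigrams recovered -> 3/7\n",
        "precision = overlap / len(set(candidate))  # share of candidate unigrams supported -> 3/4\n",
        "\n",
        "print(f\"recall={recall:.2f}, precision={precision:.2f}\")"
      ]
    },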
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rDwfndw9OAW9"
      },
      "source": [
        "- Initialize `rouge_score` object"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1SEVHIrk69kj"
      },
      "outputs": [],
      "source": [
        "scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rouge2\", \"rougeL\"], use_stemmer=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_X9vv_gMORkr"
      },
      "source": [
        "- Define function to calculate rouge score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P6C6EkvFOQzW"
      },
      "outputs": [],
      "source": [
        "def get_rouge_score(groundTruth, prediction):\n",
        "    \"\"\"Function to compute rouge score.\n",
        "\n",
        "    Args:\n",
        "      groundTruth: The ground truth text.\n",
        "      prediction: The predicted text.\n",
        "    Returns:\n",
        "      The rouge score.\n",
        "    \"\"\"\n",
        "    scores = scorer.score(target=groundTruth, prediction=prediction)\n",
        "    return scores"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J6qBe-Mbtem_"
      },
      "source": [
        "- Single instance evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BtP0f3GO7zG7"
      },
      "outputs": [],
      "source": [
        "get_rouge_score(val_df.loc[0, \"ground_truth\"], val_df.loc[0, \"basePredictions\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zl1PpGA9oWE"
      },
      "outputs": [],
      "source": [
        "def calculate_metrics(val_df, prediction_col=\"basePredictions\"):\n",
        "    \"\"\"Function to compute rouge scores for all instances in the validation dataset.\n",
        "    Args:\n",
        "      val_df: The validation dataframe.\n",
        "      prediction_col: The column name of the predictions.\n",
        "    Returns:\n",
        "      A dataframe containing the rouge scores.\n",
        "    \"\"\"\n",
        "    records = []\n",
        "    for row, instance in val_df.iterrows():\n",
        "        scores = get_rouge_score(instance[\"ground_truth\"], instance[prediction_col])\n",
        "        records.append(\n",
        "            {\n",
        "                \"rouge1_precision\": scores.get(\"rouge1\").precision,\n",
        "                \"rouge1_recall\": scores.get(\"rouge1\").recall,\n",
        "                \"rouge1_fmeasure\": scores.get(\"rouge1\").fmeasure,\n",
        "                \"rouge2_precision\": scores.get(\"rouge2\").precision,\n",
        "                \"rouge2_recall\": scores.get(\"rouge2\").recall,\n",
        "                \"rouge2_fmeasure\": scores.get(\"rouge2\").fmeasure,\n",
        "                \"rougeL_precision\": scores.get(\"rougeL\").precision,\n",
        "                \"rougeL_recall\": scores.get(\"rougeL\").recall,\n",
        "                \"rougeL_fmeasure\": scores.get(\"rougeL\").fmeasure,\n",
        "            }\n",
        "        )\n",
        "    metrics = pd.DataFrame(records)\n",
        "    return metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SS2UrB9g8NBt"
      },
      "outputs": [],
      "source": [
        "evaluation_df_stats = calculate_metrics(val_df, prediction_col=\"basePredictions\")\n",
        "evaluation_df_stats"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZEyRYhEBZwy9"
      },
      "outputs": [],
      "source": [
        "print(\"Mean rougeL_precision is\", evaluation_df_stats.rougeL_precision.mean())\n",
        "print(\"Mean rougeL_recall is\", evaluation_df_stats.rougeL_recall.mean())\n",
        "print(\"Mean rougeL_fmeasure is\", evaluation_df_stats.rougeL_fmeasure.mean())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uYAjjpdG_cpP"
      },
      "source": [
        "## Fine-tune the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lQehNcLG_4Nc"
      },
      "source": [
        "You can create a supervised fine-tuning job by using the Google Gen AI SDK for Python.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d1263b90fbc4"
      },
      "source": [
        "When you run a supervised fine-tuning job, the model learns additional parameters that help it encode the necessary information to perform the desired task or learn the desired behavior. These parameters are used during inference. The output of the tuning job is a new model that combines the newly learned parameters with the original model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EyqBRoY5rscI"
      },
      "source": [
        "**Tuning Job parameters**\n",
        "\n",
        "- `source_model`: Specifies the base Gemini model version you want to fine-tune.\n",
        "- `train_dataset`: Path to your training data in JSONL format.\n",
        "\n",
        "\n",
        " *Optional parameters*\n",
        " - `validation_dataset`: If provided, this data is used to evaluate the model during tuning.\n",
        " - `tuned_model_display_name`: Display name for the tuned model.\n",
        "\n",
        " *Hyperparameters*\n",
        " - `epochs`: The number of training epochs to run.\n",
        " - `learning_rate_multiplier`: A value to scale the learning rate during training.\n",
        " - `adapter_size` : Gemini 2.5 Flash supports Adapter length [1, 2, 4, 8], default value is 4.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UJ0gxBeyqO9k"
      },
      "source": [
        "**Note: The default hyperparameter settings are optimized for optimal performance based on rigorous testing and are recommended for initial use. Users may customize these parameters to address specific performance requirements.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "anlX0A5aAIPx"
      },
      "source": [
        "- Check out the [documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-use-supervised-tuning#tuning_hyperparameters) to learn more.\n",
        "- [Gen AI SDK for tuning job](https://googleapis.github.io/python-genai/genai.html#genai.types.CreateTuningJobConfig)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_vbe8o4_8qV6"
      },
      "outputs": [],
      "source": [
        "tuned_model_display_name = \"[DISPLAY NAME FOR TUNED MODEL]\"  # @param {type:\"string\"}\n",
        "\n",
        "training_dataset = {\n",
        "    \"gcs_uri\": f\"{BUCKET_URI}/train/train.jsonl\",\n",
        "}\n",
        "\n",
        "validation_dataset = types.TuningValidationDataset(\n",
        "    gcs_uri=f\"{BUCKET_URI}/val/val.jsonl\"\n",
        ")\n",
        "\n",
        "\n",
        "sft_tuning_job = client.tunings.tune(\n",
        "    base_model=base_model,\n",
        "    training_dataset=training_dataset,\n",
        "    config=types.CreateTuningJobConfig(\n",
        "        adapter_size=\"ADAPTER_SIZE_EIGHT\",\n",
        "        epoch_count=1,  # set to one to keep time and cost low\n",
        "        tuned_model_display_name=tuned_model_display_name,\n",
        "        validation_dataset=validation_dataset,\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hSDpQGUeERcH"
      },
      "outputs": [],
      "source": [
        "# Get the tuning job info.\n",
        "tuning_job = client.tunings.get(name=sft_tuning_job.name)\n",
        "tuning_job"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RE1a3AgRsqJh"
      },
      "source": [
        "**Note: Tuning time depends on several factors, such as training data size, number of epochs, learning rate multiplier, etc.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qHlfSLjKsruX"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ It will take 30-40 mins for the model tuning job to complete on the provided dataset and set configurations/hyperparameters. ⚠️</b>\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "97EUpJwisv_Q"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "# Wait for job completion\n",
        "\n",
        "running_states = [\n",
        "    \"JOB_STATE_PENDING\",\n",
        "    \"JOB_STATE_RUNNING\",\n",
        "]\n",
        "\n",
        "while tuning_job.state.name in running_states:\n",
        "    tuning_job = client.tunings.get(name=sft_tuning_job.name)\n",
        "    time.sleep(10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5680557f-67bd-4e8c-a383-02ab655246c5"
      },
      "source": [
        "## Evaluation Post-tuning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c3d1f75bddea"
      },
      "source": [
        "- Evaluate the Gemini model on the validation dataset with tuned model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bK2Cyrhavw-Y"
      },
      "outputs": [],
      "source": [
        "tuned_model = tuning_job.tuned_model.endpoint\n",
        "tuning_experiment_name = tuning_job.experiment\n",
        "\n",
        "print(\"Tuned model experiment\", tuning_experiment_name)\n",
        "print(\"Tuned model endpoint resource name:\", tuned_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oVEEGZ-cuYx2"
      },
      "source": [
        "- Get a prediction from tuned model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bp4yHwjNJbLQ"
      },
      "outputs": [],
      "source": [
        "response = client.models.generate_content(\n",
        "    model=tuned_model,\n",
        "    contents=[\n",
        "        types.Part.from_uri(file_uri=str(query_image_uri), mime_type=\"image/jpeg\"),\n",
        "        task_prompt,\n",
        "    ],\n",
        "    # Optional config\n",
        "    config={\n",
        "        \"temperature\": 0,\n",
        "    },\n",
        ")\n",
        "\n",
        "print(response.text.strip())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s_1-lbJZugY0"
      },
      "source": [
        "- Evaluate the tuned model on entire validation set"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B7sRtCFCUiag"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ It will take ~1 min for the model to generate predictions on the provided validation dataset. ⚠️</b>\n",
        "</div>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pWxg3i3a391K"
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "predictions_tuned = run_eval(val_df, model=tuned_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V0wJNPEf5-6I"
      },
      "outputs": [],
      "source": [
        "val_df.loc[:, \"tunedPredictions\"] = predictions_tuned"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "og4hVmwCuuPW"
      },
      "outputs": [],
      "source": [
        "evaluation_df_post_tuning_stats = calculate_metrics(\n",
        "    val_df, prediction_col=\"tunedPredictions\"\n",
        ")\n",
        "evaluation_df_post_tuning_stats"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "heKx9Lu5vBYb"
      },
      "source": [
        "- Improvement"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X2AVUCh3S656"
      },
      "outputs": [],
      "source": [
        "evaluation_df_post_tuning_stats.rougeL_precision.mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kTnfegPcvC-P"
      },
      "outputs": [],
      "source": [
        "improvement = round(\n",
        "    (\n",
        "        (\n",
        "            evaluation_df_post_tuning_stats.rougeL_precision.mean()\n",
        "            - evaluation_df_stats.rougeL_precision.mean()\n",
        "        )\n",
        "        / evaluation_df_stats.rougeL_precision.mean()\n",
        "    )\n",
        "    * 100,\n",
        "    2,\n",
        ")\n",
        "print(\n",
        "    f\"Model tuning has improved the rougeL_precision by {improvement}% (result might differ based on each tuning iteration)\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qrs0o6-p6Ebr"
      },
      "outputs": [],
      "source": [
        "# Save predicitons\n",
        "predictions_all = val_df.to_csv(\"validation_pred.csv\", index=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yUuvCQ2O-1OW"
      },
      "source": [
        "## Conclusion"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "me908QT9-26J"
      },
      "source": [
        "Performance could be further improved:\n",
        "- By adding more training samples. In general, improve your training data quality and/or quantity towards getting a more diverse and comprehensive dataset for your task\n",
        "- By tuning the hyperparameters, such as epochs, learning rate multiplier or adapter size\n",
        "  - To find the optimal number of epochs for your dataset, we recommend experimenting with different values. While increasing epochs can lead to better performance, it's important to be mindful of overfitting, especially with smaller datasets. If you see signs of overfitting, reducing the number of epochs can help mitigate the issue\n",
        "- You may try different prompt structures/formats and opt for the one with better performance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F7pq-hvxvy8_"
      },
      "source": [
        "## Cleaning up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LokkxNS0vzM-"
      },
      "source": [
        "To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud\n",
        "project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.\n",
        "\n",
        "\n",
        "Otherwise, you can delete the individual resources you created in this tutorial.\n",
        "\n",
        "Refer to this [instructions](https://cloud.google.com/vertex-ai/docs/tutorials/image-classification-custom/cleanup#delete_resources) to delete the resources from console."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H38EHjj3vwib"
      },
      "outputs": [],
      "source": [
        "# Delete Experiment.\n",
        "delete_experiments = True\n",
        "if delete_experiments:\n",
        "    experiments_list = aiplatform.Experiment.list()\n",
        "    for experiment in experiments_list:\n",
        "        if experiment.resource_name == tuning_experiment_name:\n",
        "            print(experiment.resource_name)\n",
        "            experiment.delete()\n",
        "            break\n",
        "\n",
        "print(\"***\" * 10)\n",
        "\n",
        "# Delete Endpoint.\n",
        "delete_endpoint = True\n",
        "# If force is set to True, all deployed models on this\n",
        "# Endpoint will be first undeployed.\n",
        "if delete_endpoint:\n",
        "    for endpoint in aiplatform.Endpoint.list():\n",
        "        if endpoint.resource_name == tuned_model:\n",
        "            print(endpoint.resource_name)\n",
        "            endpoint.delete(force=True)\n",
        "            break\n",
        "\n",
        "print(\"***\" * 10)\n",
        "\n",
        "# Delete Cloud Storage Bucket.\n",
        "delete_bucket = True\n",
        "if delete_bucket:\n",
        "    ! gsutil -m rm -r $BUCKET_URI"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "sft_gemini_on_image_captioning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
